import torch
import torch.nn as nn


class SquarePathIntegral(nn.Module):
    """Learnable weighted sum of inner-product matrices over stacked copies.

    The module owns one learnable weight per copy, ``lambda_copies``, of shape
    ``(n_q * ntimes, 1, 1)``. It is initialised from a standard normal
    distribution divided by ``scaler``. Dimension ``-3`` of the operands
    indexes copies. The result is

        sum_k lambda_k * (in_k @ out_k^T)

    where the sum runs over ``n_q`` consecutive copies. When ``ntimes > 1``,
    the ``ntimes`` groups are not summed together. Instead they are
    concatenated along the row axis (dim ``-2``), with row index
    ``group * n_rows + row``.
    """

    def __init__(self, scaler, n_q, ntimes=1):
        super(SquarePathIntegral, self).__init__()
        self.lambda_copies = nn.Parameter(torch.randn(n_q * ntimes, 1, 1) / scaler)
        self.n_q = n_q
        self.scaler = scaler
        # Fallback operands used when _path_integral is called without them.
        self.in_subsystem = None
        self.out_subsystem = None
        # Written only by the "one in-copy, n_q out-copies" fast path: the
        # lambda-weighted sum of out_subsystem with a singleton dim at -3.
        # It keeps a reference to the most recently computed tensor.
        self.out_subsystem_sum = None

    def _weighted_copy_sum(self, x):
        """Scale the copies on dim -3 by ``lambda_copies`` and sum each group of ``n_q``.

        The ``ntimes`` group axis is merged into dim -2, which is what places
        separate groups on separate rows of the result.
        """
        x = x * self.lambda_copies
        x = x.unflatten(-3, (self.n_q, -1)).flatten(-3, -2)
        return x.sum(-3)

    def _path_integral(self, in_subsystem=None, out_subsystem=None):
        """Return the lambda-weighted sum of ``in_subsystem @ out_subsystem^T``.

        Args:
            in_subsystem: tensor with copies on dim -3 (at least 3-D). If it is
                None, ``self.in_subsystem`` is used instead.
            out_subsystem: tensor with copies on dim -3 (at least 3-D). If it is
                None, ``self.out_subsystem`` is used instead.

        Returns:
            A tensor whose last two dims are (rows of in, rows of out). The row
            dim is multiplied by ``ntimes`` when ``ntimes > 1``.

        Raises:
            ValueError: if either operand is missing both as an argument and
                as the corresponding attribute.
        """
        if in_subsystem is None:
            in_subsystem = self.in_subsystem
        if out_subsystem is None:
            out_subsystem = self.out_subsystem
        if in_subsystem is None or out_subsystem is None:
            raise ValueError("Path integral requires computational object.")

        n_in = in_subsystem.shape[-3]
        n_out = out_subsystem.shape[-3]
        # The fast paths below rely on linearity:
        #     sum_k l_k (a @ b_k^T) == a @ (sum_k l_k b_k)^T
        # This avoids materialising n_q full distance matrices. It only matches
        # the general path's output layout when there is exactly one group of
        # n_q weights (ntimes == 1). Otherwise the groups would end up on the
        # wrong axis, so in that case we fall through to the general path.
        single_group = self.lambda_copies.shape[0] == self.n_q

        if single_group and n_in == 1 and n_out == self.n_q:
            out_subsystem_sum = self._weighted_copy_sum(out_subsystem)
            self.out_subsystem_sum = out_subsystem_sum.unsqueeze(-3)
            return torch.matmul(in_subsystem.squeeze(-3),
                                out_subsystem_sum.transpose(-2, -1))

        if single_group and n_in == self.n_q and n_out == 1:
            in_subsystem_sum = self._weighted_copy_sum(in_subsystem)
            return torch.matmul(in_subsystem_sum,
                                out_subsystem.squeeze(-3).transpose(-2, -1))

        # General path: per-copy distance matrices, then the weighted sum.
        dist = torch.matmul(in_subsystem, out_subsystem.transpose(-2, -1))
        return self._weighted_copy_sum(dist)  # (n_in, n_out)

    def forward(self, in_subsystem, out_subsystem):
        """Alias for :meth:`_path_integral` with both operands supplied."""
        return self._path_integral(in_subsystem, out_subsystem)
    


class ClassCommunicator(nn.Module):
    """Mix ``glob`` through a learned square weighting of ``state``, then project.

    ``forward`` proceeds in three steps:

    1. ``pnode_agg(state, state)`` builds a weighting matrix, which is
       matrix-multiplied with ``glob``.
    2. The result goes through a three-layer MLP
       (``Linear -> LeakyReLU -> Dropout`` twice, then a final ``Linear``)
       that maps ``d_in`` features to ``q_dim`` features.
    3. A singleton axis is inserted at dim -3 whenever the output has fewer
       than 4 dims.
    """

    def __init__(self, d_in, q_dim, n_q, dropout):
        super().__init__()
        self.q_dim = q_dim
        # Registered (and randomly initialised) before the MLP, as originally.
        self.pnode_agg = SquarePathIntegral(q_dim, n_q)
        stages = [nn.Linear(d_in, q_dim)]
        for _ in range(2):
            stages += [nn.LeakyReLU(), nn.Dropout(dropout), nn.Linear(q_dim, q_dim)]
        self.glob2disp = nn.Sequential(*stages)

    def forward(self, state, glob):
        # NOTE(review): presumably the weighting is (n_pnode, n_pnode) and glob
        # is (n_pnode, d_in), giving (n_pnode, q_dim) below; confirm with callers.
        weighting = self.pnode_agg(state, state)
        projected = self.glob2disp(torch.matmul(weighting, glob))
        if projected.ndim >= 4:
            return projected
        return projected.unsqueeze(-3)
